import cv2
import os
import sys
# Put /system/lib at the front of the module search path (only if it is not
# already listed) so the imports that follow can resolve modules installed there.
# NOTE(review): cv2 is imported before this runs, so it is not looked up here.
package_path = "/system/lib"
if all(entry != package_path for entry in sys.path):
    sys.path.insert(0, package_path)

import numpy as np
import serial

def preprocess(frame):
    """Turn a BGR frame into a normalised (1, 3, 32, 32) float32 batch.

    Steps: grayscale, fixed inverted threshold at 127 (pixels <= 127 become
    255, the rest 0), replicate to three channels, bilinear resize to 32x32,
    reorder channels BGR -> RGB, map [0, 255] to [-1, 1], and move channels
    first (HWC -> CHW) with a leading batch axis of size 1.

    Args:
        frame: BGR image accepted by cv2.cvtColor(..., COLOR_BGR2GRAY).

    Returns:
        C-contiguous float32 array of shape (1, 3, 32, 32).
    """
    grey = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    mask = cv2.threshold(grey, 127, 255, cv2.THRESH_BINARY_INV)[1]

    # Three identical channels, 32x32, channel order RGB.
    small_bgr = cv2.resize(
        src=cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR),
        dsize=(32, 32),
        interpolation=cv2.INTER_LINEAR,
    )
    rgb = cv2.cvtColor(small_bgr, cv2.COLOR_BGR2RGB)

    # Shift/scale in float32; multiply by the reciprocal (not divide) so the
    # rounding matches the reference pipeline exactly.
    scaled = (rgb.astype(np.float32) - 127.5) * (1 / 127.5)
    chw = scaled.transpose(2, 0, 1)

    # Add the batch axis and materialise a contiguous copy.
    return np.ascontiguousarray(chw[np.newaxis])

def postprocess(outputs):
    """Return ``argmax()`` of the first value in the *outputs* mapping.

    Only the value that comes first in the mapping's iteration order is
    inspected; ``argmax()`` is called without an axis, so for a NumPy array the
    result indexes the flattened array. An empty mapping raises IndexError.
    """
    first_output = tuple(outputs.values())[0]
    return first_output.argmax()

def calculate_bias(original_image, block_size=11, threshold_offset=2, kernel_size=5):
    """Estimate a steering bias from a grayscale frame.

    Pipeline:
      1. Gaussian adaptive threshold, inverted. A pixel becomes 255 when it is
         at least ``threshold_offset`` darker than the Gaussian-weighted mean
         of its ``block_size`` x ``block_size`` neighbourhood; otherwise it
         becomes 0.
      2. Morphological close, then open, with a square all-ones kernel.
      3. Convex hull of all foreground pixels.
      4. Split the hull vertices at their median x into a left and a right
         group. Fit a straight line x = a*y + b to each group.
      5. Evaluate both lines at the bottom edge (y = height) and take the
         integer midpoint. The bias is that midpoint's offset from the image
         centre, scaled so +/-100 is the image edge and clipped to [-100, 100].
         A midpoint left of centre gives a positive bias.

    Args:
        original_image: 8-bit single-channel image (required by
            cv2.adaptiveThreshold).
        block_size: neighbourhood size for the adaptive threshold (odd, > 1).
        threshold_offset: constant subtracted from the neighbourhood mean.
        kernel_size: side length of the square morphology kernel.

    Returns:
        (bias, cleaned_image)
        - bias is the clipped value from step 5. It is None when there is no
          foreground, or when either side of the hull has fewer than two
          vertices.
        - cleaned_image is the binary mask after morphology.
    """
    binary = cv2.adaptiveThreshold(
        original_image,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        block_size,
        threshold_offset,
    )

    # Close fills small gaps in the foreground; open removes isolated specks.
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    cleaned_image = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    cleaned_image = cv2.morphologyEx(cleaned_image, cv2.MORPH_OPEN, kernel)

    height = cleaned_image.shape[0]
    width = cleaned_image.shape[1]
    mid_x = width // 2

    ys, xs = np.nonzero(cleaned_image > 0)
    if xs.size == 0:
        return None, cleaned_image

    # convexHull expects (x, y) point pairs.
    hull_points = cv2.convexHull(np.column_stack((xs, ys))).reshape(-1, 2)
    hull_x = hull_points[:, 0]
    hull_y = hull_points[:, 1]

    median_x = np.median(hull_x)
    left = hull_x < median_x
    right = ~left
    # A line fit needs at least two points on each side.
    if np.count_nonzero(left) < 2 or np.count_nonzero(right) < 2:
        return None, cleaned_image

    # Fit x as a function of y, so near-vertical lines stay well-conditioned.
    left_slope, left_intercept = np.polyfit(hull_y[left], hull_x[left], 1)
    right_slope, right_intercept = np.polyfit(hull_y[right], hull_x[right], 1)

    # Only the bottom edge of the frame (y = height) is used for steering.
    left_bottom = int(left_slope * height + left_intercept)
    right_bottom = int(right_slope * height + right_intercept)
    lane_center = (left_bottom + right_bottom) // 2

    bias = -100 * (lane_center - mid_x) / (width / 2)
    return np.clip(bias, -100, 100), cleaned_image

def analyze_horizontal_line(cleaned_image, original_image, edge_margin=10):
    """Classify a binary mask by the vertical distribution of its foreground.

    Args:
        cleaned_image: 2-D mask; pixels > 0 count as foreground.
        original_image: unused; kept so existing callers keep working.
        edge_margin: the mask counts as spanning the full frame height when
            its topmost foreground row is < ``edge_margin`` and its bottommost
            row is > ``height - edge_margin``. For a 480-row frame with the
            default this means top < 10 and bottom > 470.

    Returns:
        2 if there is no foreground or it spans the full frame height;
        1 if more than half of the foreground pixels lie above the midpoint
        between the topmost and bottommost foreground rows;
        0 otherwise.
    """
    height, _ = cleaned_image.shape
    ys, _ = np.nonzero(cleaned_image > 0)
    if ys.size == 0:
        return 2

    top_y = ys.min()  # smallest row index = highest point in the image
    bottom_y = ys.max()  # largest row index = lowest point in the image

    # top_y <= bottom_y always holds, so "spans the frame" must compare the top
    # row against the upper margin and the bottom row against the lower one.
    if top_y < edge_margin and bottom_y > height - edge_margin:
        return 2

    mid_y = (top_y + bottom_y) // 2
    above_mid = np.count_nonzero(ys < mid_y)
    return 1 if above_mid > ys.size / 2 else 0

def _encode_command(pred_idx, bias):
    """Build the 4-byte serial command ``b'C' <class> b'P' <position>``.

    Class byte:
      - 0 for pred_idx 0 ("Start")
      - 1 for pred_idx 1 ("End")
      - 2 for pred_idx 2 with bias < 0 ("Left")
      - 3 for pred_idx 2 with bias >= 0 ("Right")
      - 1 ("End") for any other combination

    The position byte is int(abs(bias)). Returns None when either byte value
    exceeds 127 (the protocol carries 7-bit values).
    """
    if pred_idx == 0:
        ts_class = 0
    elif pred_idx == 1:
        ts_class = 1
    elif pred_idx == 2 and bias < 0:
        ts_class = 2
    elif pred_idx == 2 and bias >= 0:
        ts_class = 3
    else:
        ts_class = 1

    ts_pos = int(abs(bias))
    if ts_pos > 127 or ts_class > 127:
        return None
    return bytes((ord('C'), ts_class, ord('P'), ts_pos))


def infer_from_camera(port='/dev/serial/by-id/usb-1a86_USB_Serial-if00-port0',
                      baudrate=115200, camera_index=0):
    """Run the capture -> classify -> serial-output loop.

    Each frame is processed as follows:
      - Convert to grayscale and binarise with a fixed inverted threshold at 127.
      - Clean the mask with a 5x5 close followed by an open.
      - Classify the mask with analyze_horizontal_line.
      - Estimate a bias from the grayscale frame with calculate_bias.
      - When a bias is available and the serial port is open, send the command
        built by _encode_command.

    Two debug windows are shown. The loop ends when 'q' is pressed or a frame
    cannot be read.

    Args:
        port: serial device the commands are written to.
        baudrate: serial baud rate.
        camera_index: device index passed to cv2.VideoCapture.

    If the serial port cannot be opened, a message is printed and the loop
    still runs (display only, nothing is sent).
    """
    cap = cv2.VideoCapture(camera_index)
    if not cap.isOpened():
        print("Error: Could not open camera.")
        return

    # Loop-invariant structuring element for the morphology steps.
    kernel = np.ones((5, 5), np.uint8)
    ser = None
    try:
        # Open the port once, up front. Reopening it every frame is slow and
        # leaves each previous handle to the garbage collector.
        try:
            ser = serial.Serial(port, baudrate, timeout=1)
        except serial.SerialException:
            print(f"无法打开串口 {port}")

        while True:
            ret, frame = cap.read()
            if not ret:
                print("Error: Could not read frame.")
                break

            gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            _, binary = cv2.threshold(gray_frame, 127, 255, cv2.THRESH_BINARY_INV)
            cleaned_image = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
            cleaned_image = cv2.morphologyEx(cleaned_image, cv2.MORPH_OPEN, kernel)

            pred_idx = analyze_horizontal_line(cleaned_image, frame)
            bias, annotated_frame = calculate_bias(gray_frame)

            if bias is not None and ser is not None and ser.is_open:
                command = _encode_command(pred_idx, bias)
                if command is not None:
                    ser.write(command)

            cv2.imshow('Original with Lines', annotated_frame)
            cv2.imshow('Binary with Horizontal Line', cleaned_image)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()
        if ser is not None and ser.is_open:
            ser.close()

if __name__ == "__main__":
    # Start the camera loop only when executed as a script, not on import.
    infer_from_camera()
